import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class PwA(nn.Module):
    """Learned sigmoid gate applied element-wise to the input features.

    A square linear layer (``in_features -> in_features``) produces one
    logit per feature; the logits are squashed with a sigmoid and used to
    rescale the input.

    Args:
        in_features: Size of the last dimension of the input tensor.
    """

    def __init__(self, in_features):
        super().__init__()
        gate_layer = nn.Linear(in_features, in_features)
        # Overwrite nn.Linear's default weight init with Kaiming-normal
        # (ReLU gain); the bias keeps the default initialisation.
        nn.init.kaiming_normal_(gate_layer.weight, nonlinearity='relu')
        self.fc = gate_layer

    def forward(self, x):
        """Return ``x`` scaled by per-feature gates in the range (0, 1)."""
        gate = self.fc(x).sigmoid()
        return x * gate

class RwA(nn.Module):
    """Average of several independent ``Linear -> ReLU`` projections.

    Each head is a square linear layer (``in_features -> in_features``)
    applied to the same input; the ReLU-activated head outputs are stacked
    and averaged, so the result has the same shape as a single head output.

    Args:
        in_features: Size of the last dimension of the input tensor.
        heads: Number of parallel heads; must be at least 1.

    Raises:
        ValueError: If ``heads`` is smaller than 1.
    """

    def __init__(self, in_features, heads=4):
        super(RwA, self).__init__()
        # With no heads, forward() would fail inside torch.stack on an empty
        # list; reject the configuration here where the cause is obvious.
        if heads < 1:
            raise ValueError(f"heads must be >= 1, got {heads}")
        self.attn_heads = nn.ModuleList([
            nn.Linear(in_features, in_features) for _ in range(heads)
        ])
        # Replace each head's default weight init with Kaiming-normal
        # (ReLU gain); biases keep the nn.Linear default.
        for head in self.attn_heads:
            nn.init.kaiming_normal_(head.weight, nonlinearity='relu')

    def forward(self, x):
        """Return the mean over heads of ``relu(head(x))``."""
        out = [F.relu(head(x)) for head in self.attn_heads]
        # Stack along a new leading "head" axis, then average it away.
        return torch.mean(torch.stack(out), dim=0)

class SMART(nn.Module):
    """ResNet-101 features passed through PwA and RwA, then a linear classifier.

    Args:
        num_classes: Number of output classes (width of the final layer).
        pretrained: When True (the default), the ResNet-101 backbone is
            initialised from torchvision's ``IMAGENET1K_V1`` weights. When
            False, the backbone is built with ``weights=None`` so no
            pretrained weights are loaded (useful when a checkpoint will be
            loaded afterwards).
    """

    def __init__(self, num_classes, *, pretrained=True):
        super(SMART, self).__init__()
        weights = models.ResNet101_Weights.IMAGENET1K_V1 if pretrained else None
        base_model = models.resnet101(weights=weights)
        # Take the feature width from the classifier layer being discarded
        # instead of hard-coding it (2048 for ResNet-101).
        feature_dim = base_model.fc.in_features
        # Keep every child module except the last one (the ResNet's own
        # `fc` classification layer).
        self.backbone = nn.Sequential(*list(base_model.children())[:-1])
        self.pwa = PwA(feature_dim)
        self.rwa = RwA(feature_dim, heads=4)
        self.fc = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        """Return per-class log-probabilities (``log_softmax`` over dim 1).

        NOTE: the output is already log-softmaxed, so it pairs with
        ``nn.NLLLoss``; ``nn.CrossEntropyLoss`` would apply log-softmax a
        second time.
        """
        # assumes x is an image batch accepted by torchvision's ResNet,
        # i.e. (batch, 3, H, W) -- TODO confirm against the data pipeline
        features = self.backbone(x)
        # Collapse everything after the batch dimension into one vector.
        features = torch.flatten(features, 1)
        features = self.pwa(features)
        features = self.rwa(features)
        logits = self.fc(features)
        return F.log_softmax(logits, dim=1)
